{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![title](pic/11.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![title](pic/12.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![title](pic/13.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![title](pic/14.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![title](pic/15.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import re\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loadDataSet():   #构建简单测试集\n",
    "    postingList=[['my', 'dog', 'has', 'flea', 'problems', 'help', 'please'],\n",
    "                 ['maybe', 'not', 'take', 'him', 'to', 'dog', 'park', 'stupid'],\n",
    "                 ['my', 'dalmation', 'is', 'so', 'cute', 'I', 'love', 'him'],\n",
    "                 ['stop', 'posting', 'stupid', 'worthless', 'garbage'],\n",
    "                 ['mr', 'licks', 'ate', 'my', 'steak', 'how', 'to', 'stop', 'him'],\n",
    "                 ['quit', 'buying', 'worthless', 'dog', 'food', 'stupid']]\n",
    "    classVec = [0,1,0,1,0,1]    #1 is abusive, 0 not\n",
    "    return postingList,classVec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def  createVocabList(dataset):#挑出所有出现过的字符\n",
    "    vocabSet=set([ vocab for x in dataset for vocab in x ])  \n",
    "    return  list(vocabSet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Dataset,label=loadDataSet()\n",
    "Dataset,label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a=createVocabList(Dataset)\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def setOfwords2Vec(vocabList,inputSet):        #向量化，词集模型\n",
    "    returnVec=[0]*len(vocabList)\n",
    "    for vocab in  inputSet:\n",
    "        if vocab in vocabList:\n",
    "            returnVec[vocabList.index(vocab)]=1\n",
    "        else:\n",
    "            print('not exist')\n",
    "    return returnVec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainMat=[]\n",
    "for postinDoc in Dataset:  #构建整个词集\n",
    "    trainMat.append(setOfwords2Vec(a,postinDoc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainMat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trainNB0(trainMatrix,trainCategory):      #计算每个词的条件概率\n",
    "    numTrainDocs = len(trainMatrix)\n",
    "    numWords = len(trainMatrix[0])\n",
    "    pAbusive = sum(trainCategory)/float(numTrainDocs)\n",
    "    p0Num = np.ones(numWords); p1Num = np.ones(numWords)      #change to ones() ，拉普拉斯平滑\n",
    "    p0Denom = 2.0; p1Denom = 2.0                        #change to 2.0\n",
    "    for i in range(numTrainDocs):\n",
    "        if trainCategory[i] == 1:\n",
    "            p1Num += trainMatrix[i]\n",
    "            p1Denom += sum(trainMatrix[i])\n",
    "        else:\n",
    "            p0Num += trainMatrix[i]\n",
    "            p0Denom += sum(trainMatrix[i])\n",
    "    p1Vect = np.log(p1Num/p1Denom)          #change to log()\n",
    "    p0Vect = np.log(p0Num/p0Denom)          #change to log()\n",
    "    return p0Vect,p1Vect,pAbusive"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainNB0(trainMat,label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def classifyNB(vec2Classify,p0,p1,pA):  #判别label，贝叶斯公式\n",
    "    p1=sum(vec2Classify*p1)+np.log(pA)\n",
    "    p0=sum(vec2Classify*p0)+np.log(1.0-pA)\n",
    "    if p0>p1:return 0\n",
    "    else: return 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def testingNB(testEntry=['love','my','dalmation']):  #测试功能\n",
    "    listOPosts,listClasses=loadDataSet()\n",
    "    myVocabList=createVocabList(listOPosts)\n",
    "    trainMat=[]\n",
    "    for postinDoc in listOPosts:\n",
    "        trainMat.append(setOfwords2Vec(myVocabList,postinDoc))\n",
    "    p0,p1,pA=trainNB0(np.array(trainMat),np.array(listClasses))\n",
    "    thisDoc=np.array(setOfwords2Vec(myVocabList,testEntry))\n",
    "    print(testEntry,'classified as:',classifyNB(thisDoc,p0,p1,pA))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "testingNB()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bagOfwords2VecMN(vocabList,inputSet):\n",
    "    returnVec=[0]*len(vocabList)\n",
    "    for word in inputSet:\n",
    "        if word in vocabList:\n",
    "            returnVec[vocabList.index(word)]+=1\n",
    "    return  returnVec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def textParse(textstring):\n",
    "    listOfTokens=re.split(r'\\W',textstring)\n",
    "    return [word.lower() for word in listOfTokens if len(word)>2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def spamTest():   #邮件分类\n",
    "    docList=[];classList=[]  \n",
    "    for i in range(1,27):\n",
    "        wordList=textParse(open(r'data\\email\\spam\\{}.txt'.format(i),encoding='gb18030',errors='ignore').read())\n",
    "        docList.append(wordList)\n",
    "        classList.append(1)\n",
    "        wordList=textParse(open(r'data\\email\\ham\\{}.txt'.format(i),encoding='gb18030',errors='ignore').read())\n",
    "        docList.append(wordList)\n",
    "        classList.append(0)\n",
    "    vocabList=createVocabList(docList)\n",
    "    #创建测试集\n",
    "    testSet=[];testIndex=[random.randint(0,49) for i in range(10)]\n",
    "    for i in testIndex:\n",
    "        testSet.append(docList[i])\n",
    "    trainMat=[];trainClasses=[]\n",
    "    for docIndex in range(50):\n",
    "        trainMat.append(setOfwords2Vec(vocabList,docList[docIndex]))\n",
    "        trainClasses.append(classList[docIndex])\n",
    "    p0,p1,pA=trainNB0(np.array(trainMat),np.array(trainClasses))\n",
    "    #\n",
    "    errorCount=0\n",
    "    for doc in testIndex:\n",
    "        wordVector=setOfwords2Vec(vocabList,docList[doc])\n",
    "        if classifyNB(np.array(wordVector),p0,p1,pA)!=classList[doc]:\n",
    "            errorCount+=1\n",
    "    print('error:%f'%(float(errorCount)/len(testSet)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "spamTest()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
